# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import gc

import tensorflow.compat.v2 as tf

# isort: off
from tensorflow.python.framework import (
    test_util as tf_test_utils,
)
from tensorflow.python.platform import test as test_lib

layers = tf.keras.layers
optimizers = tf.keras.optimizers


def _get_big_cnn_model(
    img_dim, n_channels, num_partitions, blocks_per_partition
):
    """Creates a test model whose activations are significantly larger than
    model size."""
    model = tf.keras.Sequential()
    model.add(layers.Input(shape=(img_dim, img_dim, n_channels)))
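    # Each block stacks three 5x5 convolutions; the (1, 1) "same"-padded
    # max-pools are shape-preserving, so activations stay at the full
    # img_dim x img_dim resolution and dominate memory usage.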
    for _ in range(num_partitions):
        for _ in range(blocks_per_partition):
            model.add(
                layers.Conv2D(10, 5, padding="same", activation=tf.nn.relu)
            )
            model.add(layers.MaxPooling2D((1, 1), padding="same"))
            model.add(
                layers.Conv2D(40, 5, padding="same", activation=tf.nn.relu)
            )
            model.add(layers.MaxPooling2D((1, 1), padding="same"))
            model.add(
                layers.Conv2D(20, 5, padding="same", activation=tf.nn.relu)
            )
            model.add(layers.MaxPooling2D((1, 1), padding="same"))
    model.add(layers.Flatten())
    model.add(layers.Dense(32, activation=tf.nn.relu))
    model.add(layers.Dense(10))
    return model


def _get_split_cnn_model(
    img_dim, n_channels, num_partitions, blocks_per_partition
):
    """Creates a test model that is split into `num_partitions` smaller
    models."""
    models = [tf.keras.Sequential() for _ in range(num_partitions)]
    models[0].add(layers.Input(shape=(img_dim, img_dim, n_channels)))
    for i in range(num_partitions):
        model = models[i]
        if i > 0:
            last_shape = models[i - 1].layers[-1].output_shape
            model.add(layers.Input(shape=last_shape[1:]))
        for _ in range(blocks_per_partition):
            model.add(
                layers.Conv2D(10, 5, padding="same", activation=tf.nn.relu)
            )
            model.add(layers.MaxPooling2D((1, 1), padding="same"))
            model.add(
                layers.Conv2D(40, 5, padding="same", activation=tf.nn.relu)
            )
            model.add(layers.MaxPooling2D((1, 1), padding="same"))
            model.add(
                layers.Conv2D(20, 5, padding="same", activation=tf.nn.relu)
            )
            model.add(layers.MaxPooling2D((1, 1), padding="same"))
    models[-1].add(layers.Flatten())
    models[-1].add(layers.Dense(32, activation=tf.nn.relu))
    models[-1].add(layers.Dense(10))
    return models
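
# The partitions returned above compose end to end; chaining them
# reproduces the architecture of _get_big_cnn_model. A minimal sketch
# (assuming `models` from this helper and an input batch `x` as used in
# the tests below):
#
#   out = x
#   for m in models:
#       out = m(out)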


def _compute_loss(logits, labels):
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels
        )
    )
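
# For reference, this is equivalent (given the default mean-over-batch
# reduction) to the Keras built-in loss, e.g.:
#
#   tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(
#       labels, logits)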


def _limit_gpu_memory():
    """Helper function to limit GPU memory for testing."""
    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [
                tf.config.experimental.VirtualDeviceConfiguration(
                    memory_limit=2048
                )
            ],
        )
        return True
    return False
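
# Note: an equivalent configuration via the stable (non-experimental)
# API would be, e.g. (assuming `gpus` as above):
#
#   tf.config.set_logical_device_configuration(
#       gpus[0],
#       [tf.config.LogicalDeviceConfiguration(memory_limit=2048)])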


def _get_dummy_data(img_dim, n_channels, batch_size):
    inputs = tf.ones([batch_size, img_dim, img_dim, n_channels])
    labels = tf.ones([batch_size], dtype=tf.int64)
    return inputs, labels


def _train_no_recompute(n_steps):
    """Trains a single large model without gradient checkpointing."""
    img_dim, n_channels, batch_size = 256, 1, 4
    x, y = _get_dummy_data(img_dim, n_channels, batch_size)
    model = _get_big_cnn_model(
        img_dim, n_channels, num_partitions=3, blocks_per_partition=2
    )
    optimizer = optimizers.SGD()
    losses = []
    tr_vars = model.trainable_variables
    for _ in range(n_steps):
        with tf.GradientTape() as tape:
            logits = model(x)
            loss = _compute_loss(logits, y)
            losses.append(loss)
        grads = tape.gradient(loss, tr_vars)
        optimizer.apply_gradients(zip(grads, tr_vars))
        del grads
    return losses


def _train_with_recompute(n_steps):
    """Trains a single large model with gradient checkpointing using
    tf.recompute_grad."""
    img_dim, n_channels, batch_size = 256, 1, 4
    x, y = _get_dummy_data(img_dim, n_channels, batch_size)
    # This is the same architecture as _get_big_cnn_model, but split into
    # 3 submodels.
    models = _get_split_cnn_model(
        img_dim, n_channels, num_partitions=3, blocks_per_partition=2
    )
    model1, model2, model3 = models
    # Apply gradient checkpointing to the submodels using tf.recompute_grad.
    model1_re = tf.recompute_grad(model1)
    model2_re = tf.recompute_grad(model2)
    model3_re = tf.recompute_grad(model3)
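    # tf.recompute_grad wraps each submodel so its intermediate
    # activations are not kept alive by the tape; they are recomputed
    # from the submodel's inputs during the backward pass, trading
    # compute for memory.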
    optimizer = optimizers.SGD()
    tr_vars = (
        model1.trainable_variables
        + model2.trainable_variables
        + model3.trainable_variables
    )
    losses = []
    for _ in range(n_steps):
        with tf.GradientTape() as tape:
            logits1 = model1_re(x)
            logits2 = model2_re(logits1)
            logits3 = model3_re(logits2)
            loss = _compute_loss(logits3, y)
            losses.append(loss)
        # Compute gradients outside the tape context, as in
        # _train_no_recompute; this is where the checkpointed
        # activations are recomputed.
        grads = tape.gradient(loss, tr_vars)
        optimizer.apply_gradients(zip(grads, tr_vars))
        del grads
    return losses
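
# The general tf.recompute_grad pattern used above, as a minimal sketch
# (assuming any callable `block` mapping a tensor to a tensor, with
# trainable `variables`):
#
#   block_re = tf.recompute_grad(block)
#   with tf.GradientTape() as tape:
#       out = block_re(x)                    # forward: activations discarded
#       loss = _compute_loss(out, y)
#   grads = tape.gradient(loss, variables)   # backward: activations recomputed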


@tf_test_utils.with_eager_op_as_function
class GradientCheckpointTest(tf.test.TestCase):
    def test_raises_oom_exception(self):
        self.skipTest("b/232015009: flaky test")
        if not _limit_gpu_memory():
            self.skipTest("No virtual GPUs found")
        with self.assertRaises(Exception) as context:
            _train_no_recompute(1)
        self.assertIsInstance(
            context.exception, tf.errors.ResourceExhaustedError
        )

    @tf_test_utils.disable_xla(
        "XLA does not support searching for memory-limited solvers."
    )
    def test_does_not_raise_oom_exception(self):
        if not _limit_gpu_memory():
            self.skipTest("No virtual GPUs found")
        if test_lib.is_built_with_rocm():
            self.skipTest(
                "ROCm MIOpen does not support searching for memory-limited "
                "solvers yet, so skip the subtest which would result in OOM."
            )
        n_step = 2
        losses = _train_with_recompute(n_step)
        self.assertLen(losses, n_step)

    def tearDown(self):
        super().tearDown()
        # Make sure all the models created by Keras have been deleted and
        # cleared from the global Keras graph, and force a GC pass to
        # reclaim the GPU memory.
        tf.keras.backend.clear_session()
        gc.collect()


if __name__ == "__main__":
    tf.test.main()
